'''
Author: SlytherinGe
LastEditTime: 2021-07-09 14:42:01
'''
import torch
import numpy as np
import os
import time
import sys
import xml.etree.ElementTree as ET
import cv2

import matplotlib.pyplot as plt

from aoi_box_generate import AOI_Generator

def get_mmdet_data_from_anno(anno):
    """Convert a parsed annotation dict into an mmdet-style results dict.

    Args:
        anno: mapping with ``anno['object']`` -- an iterable of objects, each
            holding ``['bndbox']['xmin'|'ymin'|'xmax'|'ymax']`` -- and
            ``anno['size']['height'|'width'|'depth']``. Every value is
            converted with ``int(float(v))``, so numbers and numeric strings
            such as ``'12'`` or ``'12.5'`` are accepted.

    Returns:
        dict with ``gt_bboxes``: int array of shape (N, 4) in
        (xmin, ymin, xmax, ymax) order -- shape (0, 4) when there are no
        objects -- and ``img_shape``: ``(height, width, depth)`` tuple of ints.
    """
    def to_int(value):
        # Going through float accepts fractional coordinates like '12.5',
        # which a bare int() on a string rejects; truncates toward zero.
        return int(float(value))

    corner_keys = ('xmin', 'ymin', 'xmax', 'ymax')
    bboxes = [
        tuple(to_int(obj['bndbox'][key]) for key in corner_keys)
        for obj in anno['object']
    ]
    # Explicit dtype + reshape keep an empty annotation as a (0, 4) int array
    # instead of the (0,) float array np.array([]) yields, so callers can
    # still index columns with gt_bboxes[:, k].
    gt_bboxes = np.array(bboxes, dtype=int).reshape(-1, 4)
    size = anno['size']
    shape = (to_int(size['height']), to_int(size['width']), to_int(size['depth']))
    return dict(gt_bboxes=gt_bboxes, img_shape=shape)


SMALL_NUM = 1e-12


class AOIMaskGenerator(object):
    """Render ground-truth boxes into a soft area-of-interest heat map.

    Each box is drawn as a filled binary rectangle, blurred with
    ``cv2.GaussianBlur`` and peak-normalised to ~1. The per-box masks are
    summed, divided by ``pos_thr`` and clipped to 1.0, so pixels whose summed
    response reaches ``pos_thr`` saturate.

    Args:
        sigma: Gaussian standard deviation (pixels) used for every box. If
            ``None``, an anisotropic sigma is derived per box from its size.
        pos_thr: Summed response at which the heat map saturates to 1.0.
        sigma_ratio: Only used when ``sigma`` is ``None``:
            sigma_x = int(width * ratio), sigma_y = int(height * ratio),
            each at least 1.
    """

    def __init__(self,
                 sigma,
                 pos_thr=1.0,
                 sigma_ratio=1.0):
        super().__init__()

        self.sigma = sigma
        self.pos_thr = pos_thr
        self.sigma_ratio = sigma_ratio

    def _box_sigmas(self, gt_bboxes):
        """Return per-box (sigma_x, sigma_y) int arrays for adaptive blurring."""
        widths = gt_bboxes[:, 2] - gt_bboxes[:, 0]
        heights = gt_bboxes[:, 3] - gt_bboxes[:, 1]
        # astype truncates toward zero; the lower bound of 1 keeps every
        # sigma positive so a degenerate box still gets a valid blur.
        sigmax = np.maximum((widths * self.sigma_ratio).astype(int), 1)
        sigmay = np.maximum((heights * self.sigma_ratio).astype(int), 1)
        return sigmax, sigmay

    def __call__(self, results):
        """Add ``results['aoi_heatmap']`` and return ``results``.

        Args:
            results: dict with ``'gt_bboxes'`` -- array-like of boxes whose
                first four columns are (x1, y1, x2, y2) -- and ``'img_shape'``
                -- a tuple starting with (h, w).

        Returns:
            The same dict; ``'aoi_heatmap'`` is a float64 (h, w) array in [0, 1].
        """
        gt_bboxes = np.asarray(results['gt_bboxes']).astype(np.intp)
        if gt_bboxes.size == 0:
            # An empty list arrives as shape (0,); give it columns so the
            # column arithmetic in _box_sigmas still works.
            gt_bboxes = gt_bboxes.reshape(0, 4)
        height, width = results['img_shape'][:2]

        if self.sigma is None:
            sigmax, sigmay = self._box_sigmas(gt_bboxes)

        # Accumulate directly instead of materialising an (N, H, W) stack.
        heat_map = np.zeros((height, width))
        for i, box in enumerate(gt_bboxes):
            x1, y1, x2, y2 = box[:4]
            mask = np.zeros((height, width))
            mask[y1:y2, x1:x2] = 1.0
            if self.sigma is not None:
                mask = cv2.GaussianBlur(mask, (0, 0), self.sigma)
            else:
                mask = cv2.GaussianBlur(mask, (0, 0),
                                        sigmaX=int(sigmax[i]),
                                        sigmaY=int(sigmay[i]))
            # SMALL_NUM guards against 0/0 for boxes with no in-image pixels.
            heat_map += mask / (np.max(mask) + SMALL_NUM)
        heat_map /= self.pos_thr
        heat_map[heat_map > 1] = 1.0

        results['aoi_heatmap'] = heat_map
        return results

class CenternessMaskGenerator(object):
    """Render ground-truth boxes into a heat map using one fixed Gaussian blur.

    Each box is drawn as a filled binary rectangle, blurred with
    ``cv2.GaussianBlur(mask, (0, 0), sigma)`` and normalised so its peak is
    1. The masks are summed and clipped to 1.0.

    Args:
        sigma: Gaussian standard deviation in pixels (default 30).
    """

    def __init__(self, sigma=30):
        super().__init__()
        self.sigma = sigma

    def __call__(self, results):
        """Add ``results['aoi_heatmap']`` and return ``results``.

        Args:
            results: dict with ``'gt_bboxes'`` -- array-like of boxes whose
                first four columns are (x1, y1, x2, y2) -- and ``'img_shape'``
                -- a tuple starting with (h, w).

        Returns:
            The same dict; ``'aoi_heatmap'`` is a float64 (h, w) array in [0, 1].
        """
        # Cast so float box arrays can be used as slice bounds (numpy rejects
        # float slice indices).
        gt_bboxes = np.asarray(results['gt_bboxes']).astype(np.intp)
        height, width = results['img_shape'][:2]

        # Accumulate directly instead of materialising an (N, H, W) stack.
        heat_map = np.zeros((height, width))
        for box in gt_bboxes:
            x1, y1, x2, y2 = box[:4]
            mask = np.zeros((height, width))
            mask[y1:y2, x1:x2] = 1.0
            mask = cv2.GaussianBlur(mask, (0, 0), self.sigma)
            peak = mask.max()
            # A box covering no in-image pixels leaves an all-zero mask;
            # dividing by its zero peak would produce NaN for the whole map.
            if peak > 0:
                heat_map += mask / peak
        heat_map[heat_map > 1.0] = 1.0
        results['aoi_heatmap'] = heat_map
        return results

if __name__ == '__main__':
    # Visual sanity check: build the AOI heat map for one annotation and
    # overlay it on the matching image.
    ANNO_ROOT = '/media/gejunyao/Disk1/Datasets/SSDD/VOC2012/Annotations/'
    IMG_ROOT = '/media/gejunyao/Disk1/Datasets/SSDD/VOC2012/JPEGImages/'
    OUT_ROOT = '/media/gejunyao/Disk/Gejunyao/develop/temp/big_anno'
    IDX = 69
    # IDX = 120
    ag = AOI_Generator(ANNO_ROOT, OUT_ROOT)
    mask_generator = AOIMaskGenerator(sigma=None, pos_thr=0.9, sigma_ratio=0.5)

    anno = ag.get_anno_from_index(IDX)
    det_data = mask_generator(get_mmdet_data_from_anno(anno))

    heat_map = det_data['aoi_heatmap']
    # Zero the saturated (== 1) region so only the unsaturated fringe is drawn.
    heat_map[heat_map >= 1] = 0

    # splitext keeps dots inside the stem; split('.')[0] would truncate them.
    stem = os.path.splitext(ag.anno_files[IDX])[0]
    img_path = os.path.join(IMG_ROOT, stem + '.jpg')
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None (no exception) for missing/unreadable files.
        raise FileNotFoundError('could not read image: ' + img_path)

    # Half-brightness image plus the heat map on channels 0 and 1. uint16
    # gives headroom above 255 before clipping back to 8 bits.
    # NOTE: cv2.imread yields BGR while plt.imshow reads RGB, so channels
    # 0/1 are displayed as red/green.
    overlay = np.uint16(img * 0.5)
    highlight = np.uint16(128 * heat_map)
    overlay[:, :, 0] += highlight
    overlay[:, :, 1] += highlight
    overlay = np.clip(overlay, 0, 255).astype(np.uint8)

    plt.figure()
    plt.imshow(overlay)
    plt.axis('off')
    plt.show()

    print(ag.get_anno_from_index(IDX))
    print(IDX+1 ,'/1160')

